package com.ml4ai.demo;

import com.ml4ai.nn.*;
import com.ml4ai.nn.core.Variable;
import com.ml4ai.nn.core.optimizers.Moment;
import com.ml4ai.nn.core.optimizers.NNOptimizer;
import lombok.SneakyThrows;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

/**
 * Generative adversarial network (GAN) demo on MNIST.
 *
 * <p>{@link GAN} bundles two fully-connected sub-networks: a generator that maps a latent (noise)
 * vector to a flattened 28x28 image, and a discriminator that maps a flattened image to a single
 * sigmoid score.
 */
public class GenerativeAdversarialNet {

    /** Flattened size of one MNIST image (28 x 28 pixels). */
    public static final int IMAGE_SIZE = 28 * 28;

    /** Latent (noise) dimensionality used by the no-arg {@link GAN#GAN()} constructor. */
    public static final int DEFAULT_LATENT_DIM = 3;

    /** Number of iterations of the (not yet implemented) training loop in {@link #main}. */
    private static final int TRAINING_STEPS = 10_000;

    /**
     * A GAN consisting of a generator and a discriminator, both registered as children of this
     * network so {@link #getParameters()} covers the parameters of both.
     */
    public static class GAN extends BaseForwardNetwork {

        private final ForwardNetwork generator;
        private final ForwardNetwork discriminator;

        /** Creates a GAN whose generator consumes {@code DEFAULT_LATENT_DIM}-dimensional noise. */
        public GAN() {
            this(DEFAULT_LATENT_DIM);
        }

        /**
         * Creates a GAN with the given latent dimensionality.
         *
         * @param latentDim size of the noise vector fed to the generator; must be positive
         * @throws IllegalArgumentException if {@code latentDim} is not positive
         */
        public GAN(int latentDim) {
            if (latentDim <= 0) {
                throw new IllegalArgumentException("latentDim must be positive: " + latentDim);
            }
            // Generator: latent noise -> flattened image.
            generator = new SequentialForward(
                    new Linear(latentDim, 128),
                    new LeakyRelu(),
                    new Linear(128, 256),
                    new LeakyRelu(),
                    new Linear(256, IMAGE_SIZE),
                    // Bound generated pixels to (0, 1). NOTE(review): assumes real images are
                    // scaled to [0, 1] by the data iterator (non-binarized MNIST) — confirm.
                    new Sigmoid()
            );
            // Discriminator: flattened image -> single score squashed to (0, 1).
            discriminator = new SequentialForward(
                    new Linear(IMAGE_SIZE, 256),
                    new LeakyRelu(),
                    new Linear(256, 128),
                    new LeakyRelu(),
                    new Linear(128, 1),
                    new Sigmoid()
            );
            add(generator);
            add(discriminator);
        }

        /**
         * Returns only the discriminator's parameters, so a separate optimizer can update it
         * without touching the generator.
         */
        public Variable[] getDiscriminatorParameters() {
            return discriminator.getParameters();
        }

        /**
         * Returns only the generator's parameters, so a separate optimizer can update it without
         * touching the discriminator.
         */
        public Variable[] getGeneratorParameters() {
            return generator.getParameters();
        }
    }

    /**
     * Demo entry point: builds the GAN, one optimizer per sub-network, and an MNIST iterator.
     *
     * @param argv unused
     */
    @SneakyThrows
    public static void main(String[] argv) {
        GAN gan = new GAN();
        // NOTE(review): the arguments are presumably (parameters, learning rate, momentum) — confirm
        // against the Moment constructor.
        NNOptimizer generatorOptimizer = new Moment(gan.getGeneratorParameters(), 1e-3, 0.95D);
        NNOptimizer discriminatorOptimizer = new Moment(gan.getDiscriminatorParameters(), 1e-3, 0.95D);
        int batchSize = 5;
        // The (batch, numExamples) overload would load only numExamples images; use the
        // (batchSize, train, seed) overload to iterate the training split reproducibly.
        DataSetIterator dataSetIt = new MnistDataSetIterator(batchSize, true, 1);
        for (int step = 0; step < TRAINING_STEPS; step++) {
            // TODO: alternate discriminator and generator updates using dataSetIt and the two
            // optimizers above.
        }
    }

}

